# Run-level settings. num_envs also feeds the batch_size computation below.
seed = 42
num_envs = 4

# network
batch_first = True
patch_size = 256
hidden_dim = 64
output_dim = 1
middle_dim = 32
embed_dim = 64
dropout = 0.2

# hyperparameters
alpha = 0.05
action_conf = 0.95
# NOTE(review): if gamma is used as a reward discount factor, 0.002 is
# unusually small -- confirm this value is intended.
gamma = 0.002
gae_lambda = 0.95
clip_coef = 0.2
vf_coef = 0.5
entropy_coef = 0.01
max_grad_norm = 0.5
target_kl = 0.1
calibration_size = 10


# training
policy_learning_rate = 5e-5
value_learning_rate = 5e-5
num_epochs = 1000
num_updates = 1000
# 512 is noted as an alternative value for num_steps.
num_steps = 128
policy_num_minibatches = 128
value_num_minibatches = 16
# Batch size is num_envs * num_steps, capped at 8192
# (4 * 128 = 512 with the values above).
batch_size = min(8192, int(num_steps * num_envs))
# Minibatch sizes are floored at 32. With batch_size = 512 the raw quotients
# are 4 (policy) and 32 (value), so both currently resolve to 32.
policy_minibatch_size = max(int(batch_size // policy_num_minibatches), 32)
value_minibatch_size = max(int(batch_size // value_num_minibatches), 32)
lookback = 200
retrain_gap = 1
update_alpha_gap = 1


# Step budgets, written as plain integer literals.
gradient_checkpointing_steps = 32
total_timesteps = 100_000_000
check_steps = 10_000
save_steps = 100_000
train_ratio = 0.7
clip_vloss = False

score_function = "CQR"
# Alternative quantile grids:
#   [0.05, 0.5, 0.95]
#   [0.05, 0.35, 0.65, 0.95]
#   [0.05, 0.25, 0.5, 0.75, 0.95]
#   [0.05, 0.2, 0.35, 0.5, 0.65, 0.8, 0.95]
quantile = [0.05, 0.5, 0.95]

# dataset
# Placeholder value: replace with the dataset location before running.
file_path = "PATH TO DATASET"
# Options listed for method: "mlp", "transformer"
method = "transformer"

# MLP and Transformer parameters
depth = 4
# Transformer parameters
group_size = 4
num_head = 16

# Dataset presets. Each preset assigns input_dim, dataset, window_size and
# output_window_length. Uncomment exactly one preset; if several are active,
# the last assignment wins. The "##" line under each preset lists candidate
# values for output_window_length.
### ------------------ weather ------------------
# input_dim = 21
# dataset = 'weather'
# window_size = 50
# output_window_length = 336
## output_window_length candidates: 96 192 336 720
### ------------------ traffic ------------------
# input_dim = 862
# dataset = 'traffic'
# window_size = 50
# output_window_length = 720
## output_window_length candidates: 96 192 336 720
### ------------------ illness (active) ------------------
input_dim = 7
dataset = 'illness'
window_size = 50
output_window_length = 24
## output_window_length candidates: 24 36 48 60
### ------------------ electricity ------------------
# input_dim = 321
# dataset = 'electricity'
# window_size = 50
# output_window_length = 720
## output_window_length candidates: 96 192 336 720
### ------------------ ETTh1 ------------------
# input_dim = 7
# dataset = 'ETTh1'
# window_size = 50
# output_window_length = 96
## output_window_length candidates: 96 192 336 720
### ------------------ ETTh2 ------------------
# input_dim = 7
# dataset = 'ETTh2'
# window_size = 50
# output_window_length = 192
## output_window_length candidates: 96 192 336 720
### ------------------ ETTm1 ------------------
# input_dim = 7
# dataset = 'ETTm1'
# window_size = 50
# output_window_length = 720
## output_window_length candidates: 96 192 336 720
### ------------------ ETTm2 ------------------
# input_dim = 7
# dataset = 'ETTm2'
# window_size = 50
# output_window_length = 720
## output_window_length candidates: 96 192 336 720


# logger
# NOTE(review): tensorboard / wandb look like on/off switches for the two
# logging backends -- confirm against the code that reads them.
tensorboard = True
wandb = True
project = "conformal_prediction"
# Run tag built from the active dataset preset, output horizon and model type,
# i.e. "<dataset>_out<output_window_length>_<method>".
tag = f"{dataset}_out{output_window_length}_{method}"

# path
# Placeholder value: replace with the output directory before running.
work_dir = "PATH TO WORKDIR"
# Per-dataset output root. A "/" is inserted only when work_dir is non-empty
# and does not already end with a path separator, so "/data/run" and
# "/data/run/" both yield "/data/run/core_<dataset>/". Previously a work_dir
# without a trailing separator was fused with "core_", giving
# "/data/runcore_<dataset>/...". An empty work_dir still gives relative paths.
_output_root = (
    work_dir
    if not work_dir or work_dir.endswith(("/", "\\"))
    else work_dir + "/"
) + f"core_{dataset}/"
checkpoint_path = _output_root + "checkpoint/"
tensorboard_path = _output_root + "tensorboard/"
wandb_path = _output_root + "wandb/"
log_path = _output_root + "results.log"
fig_path = _output_root + "figs/"


# environment
# Names of the per-step entries, grouped by category. The groups are
# concatenated into `transition` in the order: indicator, policy, regression,
# training.

# Indicator states
indicator_transition = ["features"]

# Action-dependent entries
policy_transition = ["policy_actions", "policy_values", "policy_logprobs", "policy_rewards"]

# Regression-specific entries
regression_transition = ["predictions", "targets"]

# Training-related entries
training_transition = [
    "training_loss", "training_rewards", "training_values", "training_advantages",
    "training_returns", "training_actions", "training_dones", "training_logprobs",
]

# Full list of entry names (a new list; the group lists are left untouched).
transition = [
    *indicator_transition,
    *policy_transition,
    *regression_transition,
    *training_transition,
]
# Per-entry specification keyed by transition name. Every entry carries
# "shape", "type" and "obs"; the three unbounded entries (features,
# predictions, targets) additionally carry "low" / "high" set to -inf / +inf.
# Shapes are built from num_envs, window_size, input_dim and output_dim, which
# are defined earlier in this file.
# NOTE(review): "obs" presumably marks entries exposed as observations --
# confirm against the code that consumes this dict.
transition_shape = {
    "features": {
        "shape": (num_envs, window_size, input_dim),
        "type": "float32",
        "low": -float("inf"),
        "high": float("inf"),
        "obs": True,
    },
    "predictions": {
        "shape": (num_envs, output_dim),
        "type": "float32",
        "low": -float("inf"),
        "high": float("inf"),
        "obs": True,
    },
    "targets": {
        "shape": (num_envs, output_dim),
        "type": "float32",
        "low": -float("inf"),
        "high": float("inf"),
        "obs": True,
    },

    # Training buffers (obs=False).
    "training_actions": {"shape": (num_envs, output_dim), "type": "float32", "obs": False},
    "training_dones": {"shape": (num_envs,), "type": "float32", "obs": False},
    "training_logprobs": {"shape": (num_envs,), "type": "float32", "obs": False},
    "training_loss": {"shape": (num_envs,), "type": "float32", "obs": False},
    "training_rewards": {"shape": (num_envs,), "type": "float32", "obs": False},
    "training_values": {"shape": (num_envs,), "type": "float32", "obs": False},
    "training_advantages": {"shape": (num_envs,), "type": "float32", "obs": False},
    "training_returns": {"shape": (num_envs,), "type": "float32", "obs": False},

    # Policy outputs (obs=True), one value per environment.
    "policy_actions": {"shape": (num_envs,), "type": "float32", "obs": True},
    "policy_values": {"shape": (num_envs,), "type": "float32", "obs": True},
    "policy_logprobs": {"shape": (num_envs,), "type": "float32", "obs": True},
    "policy_rewards": {"shape": (num_envs,), "type": "float32", "obs": True},
}